-
Notifications
You must be signed in to change notification settings - Fork 36
Feat : Add DP-SGD Transformer example using Flax NNX API | Issue #120 #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Feat : Add DP-SGD Transformer example using Flax NNX API | Issue #120 #126
Conversation
7cbfbb1 to
944df7c
Compare
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law- agreed to in writing, software |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fix typo
examples/dp_sgd_transformer_nnx.py
Outdated
| Returns: | ||
| The content of the downloaded file as a string. | ||
| """ | ||
| with urllib.request.urlopen(url) as response: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add timeout to prevent indefinite blocking
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's a good catch brother. i have now added a timeout and is definitely best practice to avoid hangs in CI/CD. I've updated download_data to include a 10-second timeout. I'm also moving the flax dependency into a proper requirements file as you suggested.
examples/dp_sgd_transformer_nnx.py
Outdated
| import urllib.request | ||
|
|
||
| from flax import nnx | ||
| from flax import nnx # pytype: disable=import-error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this line is unusual
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's not, in the cicd checks there is no flax installing dependency to when the pytype check happens, the code fails. Hence, this line is important to pass all the cicd checks.
For a long term note, we can tell the @RamSaw or @ryan112358 to add flax installing for the cicd check for no further issue.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so try adding in the requirements txt which is located in the docs folder
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The requirements.txt in docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and updates .github/workflows/ci.yml to install these.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or you can add it to the "dev" requirements in pyproject.toml
| from absl import app | ||
| from absl import flags | ||
| import flax.linen as nn | ||
| import flax.linen as nn # pytype: disable=import-error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same here too
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it's not, in the cicd checks there is no flax installing dependency to when the pytype check happens, the code fails. Hence, this line is important to pass all the cicd checks.
For a long term note, we can tell the @RamSaw or @ryan112358 to add flax installing for the cicd check for no further issue.
ryan112358
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks great ,very clean - nice work! Left some comments
examples/dp_sgd_transformer_nnx.py
Outdated
| import urllib.request | ||
|
|
||
| from flax import nnx | ||
| from flax import nnx # pytype: disable=import-error |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The requirements.txt in docs folder is intended to only contain requirements needed for documentation. The ones listed in pyproject.toml are only those needed by the core library. Probably the best thing to do is add an additional requirements.txt to the examples/ directory that includes flax, and updates .github/workflows/ci.yml to install these.
| x: Input batch (single example or microbatch). | ||
| y: Target batch (single example or microbatch). | ||
| graphdef: The static graph definition of the NNX model. | ||
| other: Non-trainable state (e.g., RNG counts). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What else other than the rng counts is captured here? Is it possible to call this argument prng and have it typed as a jax.Array, then somehow wire it through to flax? I ask because when you call clipped_grad, if the loss function contains a prng key it needs special handling.
examples/dp_sgd_transformer_nnx.py
Outdated
| Returns: | ||
| The scalar loss value. | ||
| """ | ||
| m = nnx.merge(graphdef, params, other) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Give this a descriptive name like model
examples/dp_sgd_transformer_nnx.py
Outdated
| l2_clip_norm=CLIP_NORM, | ||
| batch_argnums=(1, 2), # x and y are batched | ||
| keep_batch_dim=False, # Process per-example | ||
| return_values=True # Return loss values for logging |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You might need to pass prng_argnum here as well to ensure the random key is handled appropriately. But it might require slight refactoring of your loss function
examples/dp_sgd_transformer_nnx.py
Outdated
| functools.partial(pure_loss_fn, graphdef=graphdef, other=other), | ||
| l2_clip_norm=CLIP_NORM, | ||
| batch_argnums=(1, 2), # x and y are batched | ||
| keep_batch_dim=False, # Process per-example |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Usually we want to keep this to the default (True), unless we're doing user-level DP. If you set this to True (or remove it), can you remove the line that adds an extra batch axis in pure_loss_fn?
examples/dp_sgd_transformer_nnx.py
Outdated
| grads, loss = grad_fn(params, x, y) | ||
|
|
||
| # Aggregate gradients (mean across batch) | ||
| mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grad_fn already aggregates gradients across the batch dimension, so I think this is a bug
| # Aggregate gradients (mean across batch) | ||
| mean_grads = jax.tree.map(lambda g: jnp.mean(g, axis=0), grads) | ||
|
|
||
| # Add Privacy Noise |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll leave it up to your discretion, but I think these inline comments can be removed.
| # Training loop | ||
| print(f"Training for {NUM_STEPS} steps...") | ||
| for step in range(NUM_STEPS): | ||
| batch = get_batch(data, BATCH_SIZE, CONTEXT_LENGTH) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In an ideal world this would use poisson sampling / jax_privacy.batch_selection. It's fine to leave a TODO for now and add it in a follow-up
examples/dp_sgd_transformer_nnx.py
Outdated
| ) | ||
|
|
||
| privatizer = noise_addition.gaussian_privatizer( | ||
| stddev=CLIP_NORM, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The stddev should be grad_fn.sensitiivty() * noise_multiplier. can you add NOISE_MULTIPLIER to the list of constants above?
1d03537 to
9eac33d
Compare
|
Hi @ryan112358 , I've pushed an update addressing all your feedback. Here is a summary of the changes I made:
✅ Verification: The script was verified for 10 steps locally, achieving a stable loss and passing a 10.00/10 pylint check. Remind me if new changes are required! |
|
#128 might fix the ci failures easy to debug |
That's an Good approach for moving current CICD to modular DAG architecture. It is good for improving DX. |
|
@debanganghosh08 , since now the new ci pipeline and new dependency flow has been introduced, so there will ci failures from now on. As you have added the one lib in examples/req...txt it will not considered from now on. Kindly first pull the lastest changes from upstream main, then delete the examples/req..txt file and add the deps to the pyproject.toml, you can see there is optional tab and a space for [examples], kindly add it there. Now a central optional deps are managed at the root pyproject.toml file |
…alse) per maintainer review
b6d6d66 to
d5a7943
Compare
Thanks for the heads-up and the clear guidance on the new dependency flow, @amyssnippet! I've just pushed an update aligning with the new modular CI. I pulled the latest upstream changes, migrated flax to the [project.optional-dependencies] section in pyproject.toml, and cleaned up the temporary requirements file. Everything should be in sync now! |
amyssnippet
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i guess check the files changed tab, there are still some files visible, kindly fix them all, i already left comments
| - name: Install example requirements | ||
| run: pip install -r examples/requirements.txt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this block of ci should not be here, it is unusual, it is not required
| examples = [ | ||
| "flax", | ||
| ] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i have already created arrays to manage all optional dependencies, check it here https://github.com/google-deepmind/jax_privacy/blob/main/pyproject.toml
i have made deps in the prev task with ci, make sure you pulled the changes properly. including this file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i guess its still available here, which is not required
This PR introduces a comprehensive example of training a Transformer model with Differential Privacy using the new Flax NNX API. While JAX Privacy provides robust support for Linen and Haiku, this addition provides a template for users moving toward the functional-object paradigm of NNX.
Key Technical Implementations:
✔️ Exhaustive State Partitioning: Utilizes nnx.split(model, nnx.Param, ...) to strictly separate trainable parameters from non-trainable state (RNG counts, etc.), ensuring the JAX tracer maintains leaf parity across functional boundaries.
✔️ Rank-Normalized Loss: Implements a rank-injection strategy within the pure loss function to account for vmap dimension-stripping. By forcing a singleton batch dimension during the forward pass, the model correctly generates 4D causal masks required by the attention mechanism.
✔️ Privacy-Safe State Reconstruction: Uses an internal nnx.merge pattern to ensure that mutations to RNG states during training remain local to the functional trace, preventing TraceContextError regressions.
✅ Verification: The script was validated on the Tiny Shakespeare dataset for 20 steps, achieving stable convergence under DP constraints (Default: CLIP_NORM=1.0).
Screenshot of output attached 👇
